{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "923cc6d7",
   "metadata": {},
   "source": [
    "# Composite-Adv Demonstration\n",
    "This notebook provides a step-by-step demonstration showing how to launch our composite adversarial attack (CAA). We use the CIFAR-10 dataset for demonstration, while other datasets could be executed similarly.\n",
    "\n",
    "![CAA Flow](figures/CAA_Flow.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4108abaf",
   "metadata": {},
   "source": [
    "## I. Install `composite-adv` Package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95a053c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "!pip install -q git+https://github.com/IBM/composite-adv.git\n",
    "!pip install -q --upgrade --no-cache-dir gdown # Download pre-trained models from google drive."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d97ef70",
   "metadata": {},
   "source": [
    "## II. Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d574e2a2",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from composite_adv.utilities import make_dataloader\n",
    "data_loader = make_dataloader(dataset_path=\"./data/\", dataset_name=\"cifar10\", batch_size=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05ab346e",
   "metadata": {},
   "source": [
    "## III. Select Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13669ea8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download Pre-trained model\n",
    "# download_gdrive('Google File ID', 'Saving Path')\n",
    "from composite_adv.utilities import download_gdrive\n",
    "download_gdrive('1109eOxG5sSIxCwe_BUKRViMrSF20Ac4c', 'cifar10-resnet_50-gat_fs.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc0689b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# load model\n",
    "from composite_adv.utilities import make_model\n",
    "model = make_model(arch=\"resnet50\", # GAT support two architectures: ['resnet50','wideresnet']\n",
    "                   dataset_name=\"cifar10\", # GAT support three datasets: ['cifar10','svhn','imagenet']\n",
    "                   checkpoint_path=\"cifar10-resnet_50-gat_fs.pt\")\n",
    "\n",
    "# Load Madry's model (https://github.com/MadryLab/robustness)\n",
    "# from composite_adv.utilities import make_madry_model\n",
    "# model = make_madry_model(arch=\"resnet50\",\n",
    "#                          dataset_name=\"cifar10\",\n",
    "#                          checkpoint_path=\"\")\n",
    "\n",
    "\n",
    "# Load TRADES model (https://github.com/yaodongyu/TRADES)\n",
    "# from composite_adv.utilities import make_trades_model\n",
    "# model = make_trades_model(arch=\"wideresnet\",\n",
    "#                           dataset_name=\"cifar10\",\n",
    "#                           checkpoint_path=\"\")\n",
    "\n",
    "\n",
    "# Send to GPU\n",
    "import torch\n",
    "if not torch.cuda.is_available():\n",
    "    print('using CPU, this will be slow')\n",
    "else:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3848fb6f",
   "metadata": {},
   "source": [
    "## IV. Evaluate Clean Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe6af91",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from composite_adv.utilities import robustness_evaluate\n",
    "from composite_adv.attacks import NoAttack\n",
    "\n",
    "attack = NoAttack()\n",
    "robustness_evaluate(model, attack, data_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7155df46",
   "metadata": {},
   "source": [
    "## V. Evaluate Robust Accuracy\n",
    "\n",
    "**CAA Configuration**\n",
    "1. Attacks Pool Selection. For simpilicity, we use the following abbreviations to specify each attack types.\n",
    "   `0`: Hue, `1`: Saturation, `2`: Rotation, `3`: Brightness, `4`: Contrast, `5`: $\\ell_\\infty$\n",
    "\n",
    "2. Attack Ordering Specify. We provide three ordering options ['fixed','random','scheduled']\n",
    "\n",
    "**Setup**\n",
    "```python\n",
    "# Specify Attack\n",
    "\n",
    "from composite_adv.attacks import CompositeAttack\n",
    "# Three Attacks (Hue->Saturation->Rotation; Fixed Order)\n",
    "attack = CompositeAttack(model, dataset=\"cifar10\", enabled_attack=(0,1,2), order_schedule=\"fixed\")\n",
    "# Semantic Attacks; Random Order\n",
    "attack = CompositeAttack(model, dataset=\"cifar10\", enabled_attack=(0,1,2,3,4), order_schedule=\"random\")\n",
    "# Full Attacks; Scheduled Order\n",
    "attack = CompositeAttack(model, dataset=\"cifar10\", enabled_attack=(0,1,2,3,4,5), order_schedule=\"scheduled\") \n",
    "\n",
    "# Model Evaluation\n",
    "from composite_adv.utilities import robustness_evaluate\n",
    "robustness_evaluate(model, attack, data_loader)\n",
    "```"
   ]
  },
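  {
   "cell_type": "markdown",
   "id": "a7c31e90",
   "metadata": {},
   "source": [
    "The index-to-attack mapping and the ordering options above can be made concrete with a small helper. The sketch below is illustrative only and does not call `composite-adv`: `ATTACK_NAMES` and `describe_pool` are hypothetical names, it assumes that `'fixed'` keeps the order given in `enabled_attack` and that `'random'` draws a random permutation, and it leaves `'scheduled'` unmodeled because this notebook does not describe how that order is chosen.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "# Index -> attack type, copied from the list above.\n",
    "ATTACK_NAMES = {0: 'Hue', 1: 'Saturation', 2: 'Rotation',\n",
    "                3: 'Brightness', 4: 'Contrast', 5: 'L-infinity'}\n",
    "\n",
    "def describe_pool(enabled_attack, order_schedule='fixed', seed=None):\n",
    "    # Hypothetical helper: translate attack indices into readable names.\n",
    "    names = [ATTACK_NAMES[i] for i in enabled_attack]\n",
    "    if order_schedule == 'fixed':\n",
    "        return names  # keep the order given in enabled_attack\n",
    "    if order_schedule == 'random':\n",
    "        rng = random.Random(seed)\n",
    "        return rng.sample(names, len(names))  # shuffled copy\n",
    "    if order_schedule == 'scheduled':\n",
    "        return None  # ordering rule not described here, so not modeled\n",
    "    raise ValueError(f'unknown order_schedule: {order_schedule}')\n",
    "\n",
    "print(' -> '.join(describe_pool((0, 1, 2))))  # Hue -> Saturation -> Rotation\n",
    "```\n",
    "\n",
    "With `order_schedule='random'`, passing a `seed` makes the shuffled order reproducible, which helps when comparing runs."
   ]
  },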
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5211deba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from composite_adv.attacks import CompositeAttack\n",
    "# Full Attacks; Scheduled Order\n",
    "attack = CompositeAttack(model, dataset=\"cifar10\", enabled_attack=(0,1,2,3,4,5), order_schedule=\"scheduled\")\n",
    "\n",
    "from composite_adv.utilities import robustness_evaluate\n",
    "robust_accuracy, attack_success_rate = robustness_evaluate(model, attack, data_loader)\n",
    "print(\"Robust Accuracy:\", robust_accuracy)\n",
    "print(\"Attack Success Rate:\", attack_success_rate)"
   ]
  },
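  {
   "cell_type": "markdown",
   "id": "b48d2f61",
   "metadata": {},
   "source": [
    "The two numbers printed above can be read as follows. This sketch assumes one common pair of definitions, which may differ from what `robustness_evaluate` computes internally: robust accuracy is the fraction of all samples still classified correctly after the attack, and attack success rate is the fraction of originally correct samples that the attack flips to a wrong prediction. The function name `summarize_attack` and the toy labels are hypothetical.\n",
    "\n",
    "```python\n",
    "def summarize_attack(labels, clean_preds, adv_preds):\n",
    "    # Returns (robust_accuracy, attack_success_rate), both in percent.\n",
    "    total = len(labels)\n",
    "    clean_correct = [c == y for c, y in zip(clean_preds, labels)]\n",
    "    adv_correct = [a == y for a, y in zip(adv_preds, labels)]\n",
    "    robust_accuracy = 100.0 * sum(adv_correct) / total\n",
    "    # Count samples that were right before the attack and wrong after it.\n",
    "    flipped = sum(c and not a for c, a in zip(clean_correct, adv_correct))\n",
    "    n_clean = sum(clean_correct)\n",
    "    attack_success_rate = 100.0 * flipped / n_clean if n_clean else 0.0\n",
    "    return robust_accuracy, attack_success_rate\n",
    "\n",
    "labels      = [0, 1, 2, 3]\n",
    "clean_preds = [0, 1, 2, 0]   # 3 of 4 correct before the attack\n",
    "adv_preds   = [0, 5, 7, 0]   # 1 of 4 correct after the attack\n",
    "print(summarize_attack(labels, clean_preds, adv_preds))\n",
    "```\n",
    "\n",
    "In this toy case one of four samples survives the attack (25% robust accuracy), and two of the three originally correct predictions are flipped (about 66.7% attack success rate)."
   ]
  },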
  {
   "cell_type": "markdown",
   "id": "6887f8f2",
   "metadata": {},
   "source": [
    "## VI. Visualize CAA examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56456498",
   "metadata": {},
   "outputs": [],
   "source": [
    "from composite_adv.attacks import CompositeAttack\n",
    "import torchvision\n",
    "attack = CompositeAttack(model, enabled_attack=(0,1,2,3,4,5), order_schedule=\"scheduled\")\n",
    "\n",
    "def imgshow(img):\n",
    "    import matplotlib.pyplot as plt\n",
    "    import numpy as np\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')"
   ]
  },
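  {
   "cell_type": "markdown",
   "id": "c59e3072",
   "metadata": {},
   "source": [
    "`imgshow` transposes the array because PyTorch image tensors are laid out as (channels, height, width), while `plt.imshow` expects (height, width, channels). A minimal NumPy-only sketch of that axis permutation, using a made-up array in place of a real image grid:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in for a 3-channel image grid in (C, H, W) layout.\n",
    "chw = np.arange(3 * 2 * 4).reshape(3, 2, 4)\n",
    "\n",
    "# Move the channel axis last: (C, H, W) -> (H, W, C).\n",
    "hwc = np.transpose(chw, (1, 2, 0))\n",
    "\n",
    "print(chw.shape, '->', hwc.shape)  # (3, 2, 4) -> (2, 4, 3)\n",
    "# The pixel at row 1, column 2 keeps its per-channel values.\n",
    "assert (hwc[1, 2] == chw[:, 1, 2]).all()\n",
    "```\n",
    "\n",
    "Only the axis order changes; no pixel values are modified."
   ]
  },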
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dbe842f",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs, labels = next(iter(data_loader))\n",
    "ori_images, ori_labels = inputs[:5].cuda(), labels[:5].cuda()\n",
    "adv_images = attack(ori_images, ori_labels)\n",
    "ori_grid = torchvision.utils.make_grid(ori_images.cpu(), nrow=5, padding=1)\n",
    "adv_grid = torchvision.utils.make_grid(adv_images.cpu(), nrow=5, padding=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18d7846e",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgshow(ori_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccba50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgshow(adv_grid)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
